{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# D2Go Beginner's Tutorial\n",
    "\n",
    "This is a beginner tutorial for the [d2go project](https://github.com/facebookresearch/d2go). We will walk through some basic usage of d2go, including:\n",
    "- Run inference on images or videos with a pretrained d2go model\n",
    "- Load a new dataset and train a d2go model\n",
    "- Export models to int8 using post-training quantization\n",
    "\n",
    "Before running this tutorial, please install d2go by following the [installation instructions](https://github.com/facebookresearch/d2go)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference with Pre-trained Models\n",
    "\n",
    "In this section, we show how to load pretrained models with the d2go `model_zoo` API, run predictions with d2go models, and visualize the output.\n",
    "\n",
    "- First, import the model zoo API from d2go and load a pretrained Faster R-CNN model with an FBNetV3 backbone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/snowytiger/.local/share/virtualenvs/d2go-x9zKx9Ui/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "INFO:d2go.modeling.backbone.fbnet_v2:Using un-unified arch_def for ARCH \"FBNetV3_A\" (without scaling):\n",
      "trunk\n",
      "- [('conv_k3', 16, 2, 1), ('ir_k3', 16, 1, 2, {'expansion': 1}, {'less_se_channels': False})]\n",
      "- [('ir_k5', 24, 2, 1, {'expansion': 4}, {'less_se_channels': False}), ('ir_k5', 24, 1, 3, {'expansion': 3}, {'less_se_channels': False})]\n",
      "- [('ir_k5_se', 32, 2, 1, {'expansion': 4}, {'less_se_channels': False}), ('ir_k3_se', 32, 1, 3, {'expansion': 3}, {'less_se_channels': False})]\n",
      "- [('ir_k5', 64, 2, 1, {'expansion': 4}, {'less_se_channels': False}), ('ir_k3', 64, 1, 3, {'expansion': 3}, {'less_se_channels': False}), ('ir_k5_se', 112, 1, 1, {'expansion': 4}, {'less_se_channels': False}), ('ir_k5_se', 112, 1, 5, {'expansion': 3}, {'less_se_channels': False})]\n",
      "rpn\n",
      "- [('ir_k5_se', 112, 1, 5, {'expansion': 3}, {'less_se_channels': False})]\n",
      "bbox\n",
      "- [('ir_k5_se', 184, 2, 1, {'expansion': 4}, {'less_se_channels': False}), ('ir_k3_se', 184, 1, 4, {'expansion': 4}, {'less_se_channels': False}), ('ir_k5_se', 200, 1, 1, {'expansion': 6}, {'less_se_channels': False})]\n",
      "mask\n",
      "- [('ir_k3', 128, 2, 1, {'expansion': 4}), ('ir_k3', 128, 1, 2, {'expansion': 6}), ('ir_k3', 128, -2, 1, {'expansion': 6}), ('ir_k3', 64, -2, 1, {'expansion': 3})]\n",
      "basic_args\n",
      "  {'dw_skip_bnrelu': True, 'zero_last_bn_gamma': False}\n",
      "INFO:d2go.modeling.backbone.fbnet_v2:Build FBNet using unified arch_def:\n",
      "trunk\n",
      "- {'block_op': 'conv_k3', 'block_cfg': {'out_channels': 16, 'stride': 2}, 'stage_idx': 0, 'block_idx': 0}\n",
      "- {'block_op': 'ir_k3', 'block_cfg': {'out_channels': 16, 'stride': 1, 'expansion': 1, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 1}\n",
      "- {'block_op': 'ir_k3', 'block_cfg': {'out_channels': 16, 'stride': 1, 'expansion': 1, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 2}\n",
      "- {'block_op': 'ir_k5', 'block_cfg': {'out_channels': 24, 'stride': 2, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 1, 'block_idx': 0}\n",
      "- {'block_op': 'ir_k5', 'block_cfg': {'out_channels': 24, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 1, 'block_idx': 1}\n",
      "- {'block_op': 'ir_k5', 'block_cfg': {'out_channels': 24, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 1, 'block_idx': 2}\n",
      "- {'block_op': 'ir_k5', 'block_cfg': {'out_channels': 24, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 1, 'block_idx': 3}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 32, 'stride': 2, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 2, 'block_idx': 0}\n",
      "- {'block_op': 'ir_k3_se', 'block_cfg': {'out_channels': 32, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 2, 'block_idx': 1}\n",
      "- {'block_op': 'ir_k3_se', 'block_cfg': {'out_channels': 32, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 2, 'block_idx': 2}\n",
      "- {'block_op': 'ir_k3_se', 'block_cfg': {'out_channels': 32, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 2, 'block_idx': 3}\n",
      "- {'block_op': 'ir_k5', 'block_cfg': {'out_channels': 64, 'stride': 2, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 0}\n",
      "- {'block_op': 'ir_k3', 'block_cfg': {'out_channels': 64, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 1}\n",
      "- {'block_op': 'ir_k3', 'block_cfg': {'out_channels': 64, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 2}\n",
      "- {'block_op': 'ir_k3', 'block_cfg': {'out_channels': 64, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 3}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 4}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 5}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 6}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 7}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 8}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 3, 'block_idx': 9}\n",
      "WARNING:mobile_cv.arch.utils.helper:Arguments ['width_divisor', 'dw_skip_bnrelu', 'zero_last_bn_gamma'] skipped for op Conv2d\n",
      "INFO:d2go.modeling.backbone.fbnet_v2:Build FBNet using unified arch_def:\n",
      "rpn\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 0}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 1}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 2}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 3}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 112, 'stride': 1, 'expansion': 3, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 4}\n",
      "INFO:d2go.modeling.backbone.fbnet_v2:Build FBNet using unified arch_def:\n",
      "bbox\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 184, 'stride': 2, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 0}\n",
      "- {'block_op': 'ir_k3_se', 'block_cfg': {'out_channels': 184, 'stride': 1, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 1}\n",
      "- {'block_op': 'ir_k3_se', 'block_cfg': {'out_channels': 184, 'stride': 1, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 2}\n",
      "- {'block_op': 'ir_k3_se', 'block_cfg': {'out_channels': 184, 'stride': 1, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 3}\n",
      "- {'block_op': 'ir_k3_se', 'block_cfg': {'out_channels': 184, 'stride': 1, 'expansion': 4, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 4}\n",
      "- {'block_op': 'ir_k5_se', 'block_cfg': {'out_channels': 200, 'stride': 1, 'expansion': 6, 'less_se_channels': False}, 'stage_idx': 0, 'block_idx': 5}\n",
      "INFO:d2go.modeling.model_ema:Using Model EMA.\n",
      "INFO:fvcore.common.checkpoint:[Checkpointer] Loading from https://mobile-cv.s3-us-west-2.amazonaws.com/d2go/models/268421013/model_final.pth ...\n",
      "INFO:iopath.common.file_io:URL https://mobile-cv.s3-us-west-2.amazonaws.com/d2go/models/268421013/model_final.pth cached in /home/snowytiger/.torch/iopath_cache/d2go/models/268421013/model_final.pth\n",
      "INFO:detectron2.checkpoint.c2_model_loading:Following weights matched with model:\n",
      "| Names in Model                                                        | Names in Checkpoint                                                                                                     | Shapes                             |\n",
      "|:----------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------|:-----------------------------------|\n",
      "| backbone.body.trunk0.fbnetv2_0_0.bn.*                                 | backbone.body.trunk0.fbnetv2_0_0.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                          | (16,) () (16,) (16,) (16,)         |\n",
      "| backbone.body.trunk0.fbnetv2_0_0.conv.*                               | backbone.body.trunk0.fbnetv2_0_0.conv.{bias,weight}                                                                     | (16,) (16,3,3,3)                   |\n",
      "| backbone.body.trunk0.fbnetv2_0_1.dw.conv.weight                       | backbone.body.trunk0.fbnetv2_0_1.dw.conv.weight                                                                         | (16, 1, 3, 3)                      |\n",
      "| backbone.body.trunk0.fbnetv2_0_1.pwl.bn.*                             | backbone.body.trunk0.fbnetv2_0_1.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (16,) () (16,) (16,) (16,)         |\n",
      "| backbone.body.trunk0.fbnetv2_0_1.pwl.conv.weight                      | backbone.body.trunk0.fbnetv2_0_1.pwl.conv.weight                                                                        | (16, 16, 1, 1)                     |\n",
      "| backbone.body.trunk0.fbnetv2_0_2.dw.conv.weight                       | backbone.body.trunk0.fbnetv2_0_2.dw.conv.weight                                                                         | (16, 1, 3, 3)                      |\n",
      "| backbone.body.trunk0.fbnetv2_0_2.pwl.bn.*                             | backbone.body.trunk0.fbnetv2_0_2.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (16,) () (16,) (16,) (16,)         |\n",
      "| backbone.body.trunk0.fbnetv2_0_2.pwl.conv.weight                      | backbone.body.trunk0.fbnetv2_0_2.pwl.conv.weight                                                                        | (16, 16, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_0.dw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_0.dw.conv.weight                                                                         | (64, 1, 5, 5)                      |\n",
      "| backbone.body.trunk1.fbnetv2_1_0.pw.bn.*                              | backbone.body.trunk1.fbnetv2_1_0.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (64,) () (64,) (64,) (64,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_0.pw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_0.pw.conv.weight                                                                         | (64, 16, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_0.pwl.bn.*                             | backbone.body.trunk1.fbnetv2_1_0.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (24,) () (24,) (24,) (24,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_0.pwl.conv.weight                      | backbone.body.trunk1.fbnetv2_1_0.pwl.conv.weight                                                                        | (24, 64, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_1.dw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_1.dw.conv.weight                                                                         | (72, 1, 5, 5)                      |\n",
      "| backbone.body.trunk1.fbnetv2_1_1.pw.bn.*                              | backbone.body.trunk1.fbnetv2_1_1.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (72,) () (72,) (72,) (72,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_1.pw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_1.pw.conv.weight                                                                         | (72, 24, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_1.pwl.bn.*                             | backbone.body.trunk1.fbnetv2_1_1.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (24,) () (24,) (24,) (24,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_1.pwl.conv.weight                      | backbone.body.trunk1.fbnetv2_1_1.pwl.conv.weight                                                                        | (24, 72, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_2.dw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_2.dw.conv.weight                                                                         | (72, 1, 5, 5)                      |\n",
      "| backbone.body.trunk1.fbnetv2_1_2.pw.bn.*                              | backbone.body.trunk1.fbnetv2_1_2.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (72,) () (72,) (72,) (72,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_2.pw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_2.pw.conv.weight                                                                         | (72, 24, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_2.pwl.bn.*                             | backbone.body.trunk1.fbnetv2_1_2.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (24,) () (24,) (24,) (24,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_2.pwl.conv.weight                      | backbone.body.trunk1.fbnetv2_1_2.pwl.conv.weight                                                                        | (24, 72, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_3.dw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_3.dw.conv.weight                                                                         | (72, 1, 5, 5)                      |\n",
      "| backbone.body.trunk1.fbnetv2_1_3.pw.bn.*                              | backbone.body.trunk1.fbnetv2_1_3.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (72,) () (72,) (72,) (72,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_3.pw.conv.weight                       | backbone.body.trunk1.fbnetv2_1_3.pw.conv.weight                                                                         | (72, 24, 1, 1)                     |\n",
      "| backbone.body.trunk1.fbnetv2_1_3.pwl.bn.*                             | backbone.body.trunk1.fbnetv2_1_3.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (24,) () (24,) (24,) (24,)         |\n",
      "| backbone.body.trunk1.fbnetv2_1_3.pwl.conv.weight                      | backbone.body.trunk1.fbnetv2_1_3.pwl.conv.weight                                                                        | (24, 72, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_0.dw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_0.dw.conv.weight                                                                         | (96, 1, 5, 5)                      |\n",
      "| backbone.body.trunk2.fbnetv2_2_0.pw.bn.*                              | backbone.body.trunk2.fbnetv2_2_0.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (96,) () (96,) (96,) (96,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_0.pw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_0.pw.conv.weight                                                                         | (96, 24, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_0.pwl.bn.*                             | backbone.body.trunk2.fbnetv2_2_0.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (32,) () (32,) (32,) (32,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_0.pwl.conv.weight                      | backbone.body.trunk2.fbnetv2_2_0.pwl.conv.weight                                                                        | (32, 96, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_0.se.se.0.conv.*                       | backbone.body.trunk2.fbnetv2_2_0.se.se.0.conv.{bias,weight}                                                             | (24,) (24,96,1,1)                  |\n",
      "| backbone.body.trunk2.fbnetv2_2_0.se.se.1.*                            | backbone.body.trunk2.fbnetv2_2_0.se.se.1.{bias,weight}                                                                  | (96,) (96,24,1,1)                  |\n",
      "| backbone.body.trunk2.fbnetv2_2_1.dw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_1.dw.conv.weight                                                                         | (96, 1, 3, 3)                      |\n",
      "| backbone.body.trunk2.fbnetv2_2_1.pw.bn.*                              | backbone.body.trunk2.fbnetv2_2_1.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (96,) () (96,) (96,) (96,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_1.pw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_1.pw.conv.weight                                                                         | (96, 32, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_1.pwl.bn.*                             | backbone.body.trunk2.fbnetv2_2_1.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (32,) () (32,) (32,) (32,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_1.pwl.conv.weight                      | backbone.body.trunk2.fbnetv2_2_1.pwl.conv.weight                                                                        | (32, 96, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_1.se.se.0.conv.*                       | backbone.body.trunk2.fbnetv2_2_1.se.se.0.conv.{bias,weight}                                                             | (24,) (24,96,1,1)                  |\n",
      "| backbone.body.trunk2.fbnetv2_2_1.se.se.1.*                            | backbone.body.trunk2.fbnetv2_2_1.se.se.1.{bias,weight}                                                                  | (96,) (96,24,1,1)                  |\n",
      "| backbone.body.trunk2.fbnetv2_2_2.dw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_2.dw.conv.weight                                                                         | (96, 1, 3, 3)                      |\n",
      "| backbone.body.trunk2.fbnetv2_2_2.pw.bn.*                              | backbone.body.trunk2.fbnetv2_2_2.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (96,) () (96,) (96,) (96,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_2.pw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_2.pw.conv.weight                                                                         | (96, 32, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_2.pwl.bn.*                             | backbone.body.trunk2.fbnetv2_2_2.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (32,) () (32,) (32,) (32,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_2.pwl.conv.weight                      | backbone.body.trunk2.fbnetv2_2_2.pwl.conv.weight                                                                        | (32, 96, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_2.se.se.0.conv.*                       | backbone.body.trunk2.fbnetv2_2_2.se.se.0.conv.{bias,weight}                                                             | (24,) (24,96,1,1)                  |\n",
      "| backbone.body.trunk2.fbnetv2_2_2.se.se.1.*                            | backbone.body.trunk2.fbnetv2_2_2.se.se.1.{bias,weight}                                                                  | (96,) (96,24,1,1)                  |\n",
      "| backbone.body.trunk2.fbnetv2_2_3.dw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_3.dw.conv.weight                                                                         | (96, 1, 3, 3)                      |\n",
      "| backbone.body.trunk2.fbnetv2_2_3.pw.bn.*                              | backbone.body.trunk2.fbnetv2_2_3.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (96,) () (96,) (96,) (96,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_3.pw.conv.weight                       | backbone.body.trunk2.fbnetv2_2_3.pw.conv.weight                                                                         | (96, 32, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_3.pwl.bn.*                             | backbone.body.trunk2.fbnetv2_2_3.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (32,) () (32,) (32,) (32,)         |\n",
      "| backbone.body.trunk2.fbnetv2_2_3.pwl.conv.weight                      | backbone.body.trunk2.fbnetv2_2_3.pwl.conv.weight                                                                        | (32, 96, 1, 1)                     |\n",
      "| backbone.body.trunk2.fbnetv2_2_3.se.se.0.conv.*                       | backbone.body.trunk2.fbnetv2_2_3.se.se.0.conv.{bias,weight}                                                             | (24,) (24,96,1,1)                  |\n",
      "| backbone.body.trunk2.fbnetv2_2_3.se.se.1.*                            | backbone.body.trunk2.fbnetv2_2_3.se.se.1.{bias,weight}                                                                  | (96,) (96,24,1,1)                  |\n",
      "| backbone.body.trunk3.fbnetv2_3_0.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_0.dw.conv.weight                                                                         | (128, 1, 5, 5)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_0.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_0.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (128,) () (128,) (128,) (128,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_0.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_0.pw.conv.weight                                                                         | (128, 32, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_0.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_0.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (64,) () (64,) (64,) (64,)         |\n",
      "| backbone.body.trunk3.fbnetv2_3_0.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_0.pwl.conv.weight                                                                        | (64, 128, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_1.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_1.dw.conv.weight                                                                         | (192, 1, 3, 3)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_1.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_1.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (192,) () (192,) (192,) (192,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_1.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_1.pw.conv.weight                                                                         | (192, 64, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_1.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_1.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (64,) () (64,) (64,) (64,)         |\n",
      "| backbone.body.trunk3.fbnetv2_3_1.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_1.pwl.conv.weight                                                                        | (64, 192, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_2.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_2.dw.conv.weight                                                                         | (192, 1, 3, 3)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_2.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_2.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (192,) () (192,) (192,) (192,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_2.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_2.pw.conv.weight                                                                         | (192, 64, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_2.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_2.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (64,) () (64,) (64,) (64,)         |\n",
      "| backbone.body.trunk3.fbnetv2_3_2.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_2.pwl.conv.weight                                                                        | (64, 192, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_3.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_3.dw.conv.weight                                                                         | (192, 1, 3, 3)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_3.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_3.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (192,) () (192,) (192,) (192,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_3.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_3.pw.conv.weight                                                                         | (192, 64, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_3.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_3.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (64,) () (64,) (64,) (64,)         |\n",
      "| backbone.body.trunk3.fbnetv2_3_3.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_3.pwl.conv.weight                                                                        | (64, 192, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_4.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_4.dw.conv.weight                                                                         | (256, 1, 5, 5)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_4.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_4.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (256,) () (256,) (256,) (256,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_4.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_4.pw.conv.weight                                                                         | (256, 64, 1, 1)                    |\n",
      "| backbone.body.trunk3.fbnetv2_3_4.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_4.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (112,) () (112,) (112,) (112,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_4.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_4.pwl.conv.weight                                                                        | (112, 256, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_4.se.se.0.conv.*                       | backbone.body.trunk3.fbnetv2_3_4.se.se.0.conv.{bias,weight}                                                             | (64,) (64,256,1,1)                 |\n",
      "| backbone.body.trunk3.fbnetv2_3_4.se.se.1.*                            | backbone.body.trunk3.fbnetv2_3_4.se.se.1.{bias,weight}                                                                  | (256,) (256,64,1,1)                |\n",
      "| backbone.body.trunk3.fbnetv2_3_5.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_5.dw.conv.weight                                                                         | (336, 1, 5, 5)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_5.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_5.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (336,) () (336,) (336,) (336,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_5.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_5.pw.conv.weight                                                                         | (336, 112, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_5.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_5.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (112,) () (112,) (112,) (112,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_5.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_5.pwl.conv.weight                                                                        | (112, 336, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_5.se.se.0.conv.*                       | backbone.body.trunk3.fbnetv2_3_5.se.se.0.conv.{bias,weight}                                                             | (88,) (88,336,1,1)                 |\n",
      "| backbone.body.trunk3.fbnetv2_3_5.se.se.1.*                            | backbone.body.trunk3.fbnetv2_3_5.se.se.1.{bias,weight}                                                                  | (336,) (336,88,1,1)                |\n",
      "| backbone.body.trunk3.fbnetv2_3_6.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_6.dw.conv.weight                                                                         | (336, 1, 5, 5)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_6.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_6.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (336,) () (336,) (336,) (336,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_6.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_6.pw.conv.weight                                                                         | (336, 112, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_6.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_6.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (112,) () (112,) (112,) (112,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_6.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_6.pwl.conv.weight                                                                        | (112, 336, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_6.se.se.0.conv.*                       | backbone.body.trunk3.fbnetv2_3_6.se.se.0.conv.{bias,weight}                                                             | (88,) (88,336,1,1)                 |\n",
      "| backbone.body.trunk3.fbnetv2_3_6.se.se.1.*                            | backbone.body.trunk3.fbnetv2_3_6.se.se.1.{bias,weight}                                                                  | (336,) (336,88,1,1)                |\n",
      "| backbone.body.trunk3.fbnetv2_3_7.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_7.dw.conv.weight                                                                         | (336, 1, 5, 5)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_7.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_7.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (336,) () (336,) (336,) (336,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_7.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_7.pw.conv.weight                                                                         | (336, 112, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_7.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_7.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (112,) () (112,) (112,) (112,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_7.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_7.pwl.conv.weight                                                                        | (112, 336, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_7.se.se.0.conv.*                       | backbone.body.trunk3.fbnetv2_3_7.se.se.0.conv.{bias,weight}                                                             | (88,) (88,336,1,1)                 |\n",
      "| backbone.body.trunk3.fbnetv2_3_7.se.se.1.*                            | backbone.body.trunk3.fbnetv2_3_7.se.se.1.{bias,weight}                                                                  | (336,) (336,88,1,1)                |\n",
      "| backbone.body.trunk3.fbnetv2_3_8.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_8.dw.conv.weight                                                                         | (336, 1, 5, 5)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_8.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_8.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (336,) () (336,) (336,) (336,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_8.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_8.pw.conv.weight                                                                         | (336, 112, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_8.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_8.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (112,) () (112,) (112,) (112,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_8.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_8.pwl.conv.weight                                                                        | (112, 336, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_8.se.se.0.conv.*                       | backbone.body.trunk3.fbnetv2_3_8.se.se.0.conv.{bias,weight}                                                             | (88,) (88,336,1,1)                 |\n",
      "| backbone.body.trunk3.fbnetv2_3_8.se.se.1.*                            | backbone.body.trunk3.fbnetv2_3_8.se.se.1.{bias,weight}                                                                  | (336,) (336,88,1,1)                |\n",
      "| backbone.body.trunk3.fbnetv2_3_9.dw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_9.dw.conv.weight                                                                         | (336, 1, 5, 5)                     |\n",
      "| backbone.body.trunk3.fbnetv2_3_9.pw.bn.*                              | backbone.body.trunk3.fbnetv2_3_9.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                       | (336,) () (336,) (336,) (336,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_9.pw.conv.weight                       | backbone.body.trunk3.fbnetv2_3_9.pw.conv.weight                                                                         | (336, 112, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_9.pwl.bn.*                             | backbone.body.trunk3.fbnetv2_3_9.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}                      | (112,) () (112,) (112,) (112,)     |\n",
      "| backbone.body.trunk3.fbnetv2_3_9.pwl.conv.weight                      | backbone.body.trunk3.fbnetv2_3_9.pwl.conv.weight                                                                        | (112, 336, 1, 1)                   |\n",
      "| backbone.body.trunk3.fbnetv2_3_9.se.se.0.conv.*                       | backbone.body.trunk3.fbnetv2_3_9.se.se.0.conv.{bias,weight}                                                             | (88,) (88,336,1,1)                 |\n",
      "| backbone.body.trunk3.fbnetv2_3_9.se.se.1.*                            | backbone.body.trunk3.fbnetv2_3_9.se.se.1.{bias,weight}                                                                  | (336,) (336,88,1,1)                |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.dw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.dw.conv.weight                                                    | (336, 1, 5, 5)                     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pw.bn.*         | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}  | (336,) () (336,) (336,) (336,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pw.conv.weight                                                    | (336, 112, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pwl.bn.*        | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight} | (112,) () (112,) (112,) (112,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pwl.conv.weight | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.pwl.conv.weight                                                   | (112, 336, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.se.se.0.conv.*  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.se.se.0.conv.{bias,weight}                                        | (88,) (88,336,1,1)                 |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.se.se.1.*       | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_0.se.se.1.{bias,weight}                                             | (336,) (336,88,1,1)                |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.dw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.dw.conv.weight                                                    | (336, 1, 5, 5)                     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pw.bn.*         | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}  | (336,) () (336,) (336,) (336,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pw.conv.weight                                                    | (336, 112, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pwl.bn.*        | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight} | (112,) () (112,) (112,) (112,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pwl.conv.weight | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.pwl.conv.weight                                                   | (112, 336, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.se.se.0.conv.*  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.se.se.0.conv.{bias,weight}                                        | (88,) (88,336,1,1)                 |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.se.se.1.*       | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_1.se.se.1.{bias,weight}                                             | (336,) (336,88,1,1)                |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.dw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.dw.conv.weight                                                    | (336, 1, 5, 5)                     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pw.bn.*         | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}  | (336,) () (336,) (336,) (336,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pw.conv.weight                                                    | (336, 112, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pwl.bn.*        | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight} | (112,) () (112,) (112,) (112,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pwl.conv.weight | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.pwl.conv.weight                                                   | (112, 336, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.se.se.0.conv.*  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.se.se.0.conv.{bias,weight}                                        | (88,) (88,336,1,1)                 |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.se.se.1.*       | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_2.se.se.1.{bias,weight}                                             | (336,) (336,88,1,1)                |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.dw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.dw.conv.weight                                                    | (336, 1, 5, 5)                     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pw.bn.*         | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}  | (336,) () (336,) (336,) (336,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pw.conv.weight                                                    | (336, 112, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pwl.bn.*        | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight} | (112,) () (112,) (112,) (112,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pwl.conv.weight | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.pwl.conv.weight                                                   | (112, 336, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.se.se.0.conv.*  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.se.se.0.conv.{bias,weight}                                        | (88,) (88,336,1,1)                 |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.se.se.1.*       | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_3.se.se.1.{bias,weight}                                             | (336,) (336,88,1,1)                |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.dw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.dw.conv.weight                                                    | (336, 1, 5, 5)                     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pw.bn.*         | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}  | (336,) () (336,) (336,) (336,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pw.conv.weight  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pw.conv.weight                                                    | (336, 112, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pwl.bn.*        | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight} | (112,) () (112,) (112,) (112,)     |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pwl.conv.weight | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.pwl.conv.weight                                                   | (112, 336, 1, 1)                   |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.se.se.0.conv.*  | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.se.se.0.conv.{bias,weight}                                        | (88,) (88,336,1,1)                 |\n",
      "| proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.se.se.1.*       | proposal_generator.rpn_head.rpn_feature.0.fbnetv2_0_4.se.se.1.{bias,weight}                                             | (336,) (336,88,1,1)                |\n",
      "| proposal_generator.rpn_head.rpn_regressor.bbox_pred.*                 | proposal_generator.rpn_head.rpn_regressor.bbox_pred.{bias,weight}                                                       | (60,) (60,112,1,1)                 |\n",
      "| proposal_generator.rpn_head.rpn_regressor.cls_logits.*                | proposal_generator.rpn_head.rpn_regressor.cls_logits.{bias,weight}                                                      | (15,) (15,112,1,1)                 |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.dw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.dw.conv.weight                                                            | (448, 1, 5, 5)                     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pw.bn.*                 | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}          | (448,) () (448,) (448,) (448,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pw.conv.weight                                                            | (448, 112, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pwl.bn.*                | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}         | (184,) () (184,) (184,) (184,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pwl.conv.weight         | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.pwl.conv.weight                                                           | (184, 448, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.se.se.0.conv.*          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.se.se.0.conv.{bias,weight}                                                | (112,) (112,448,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.se.se.1.*               | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_0.se.se.1.{bias,weight}                                                     | (448,) (448,112,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.dw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.dw.conv.weight                                                            | (736, 1, 3, 3)                     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pw.bn.*                 | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}          | (736,) () (736,) (736,) (736,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pw.conv.weight                                                            | (736, 184, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pwl.bn.*                | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}         | (184,) () (184,) (184,) (184,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pwl.conv.weight         | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.pwl.conv.weight                                                           | (184, 736, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.se.se.0.conv.*          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.se.se.0.conv.{bias,weight}                                                | (184,) (184,736,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.se.se.1.*               | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_1.se.se.1.{bias,weight}                                                     | (736,) (736,184,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.dw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.dw.conv.weight                                                            | (736, 1, 3, 3)                     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pw.bn.*                 | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}          | (736,) () (736,) (736,) (736,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pw.conv.weight                                                            | (736, 184, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pwl.bn.*                | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}         | (184,) () (184,) (184,) (184,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pwl.conv.weight         | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.pwl.conv.weight                                                           | (184, 736, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.se.se.0.conv.*          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.se.se.0.conv.{bias,weight}                                                | (184,) (184,736,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.se.se.1.*               | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_2.se.se.1.{bias,weight}                                                     | (736,) (736,184,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.dw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.dw.conv.weight                                                            | (736, 1, 3, 3)                     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pw.bn.*                 | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}          | (736,) () (736,) (736,) (736,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pw.conv.weight                                                            | (736, 184, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pwl.bn.*                | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}         | (184,) () (184,) (184,) (184,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pwl.conv.weight         | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.pwl.conv.weight                                                           | (184, 736, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.se.se.0.conv.*          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.se.se.0.conv.{bias,weight}                                                | (184,) (184,736,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.se.se.1.*               | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_3.se.se.1.{bias,weight}                                                     | (736,) (736,184,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.dw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.dw.conv.weight                                                            | (736, 1, 3, 3)                     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pw.bn.*                 | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}          | (736,) () (736,) (736,) (736,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pw.conv.weight                                                            | (736, 184, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pwl.bn.*                | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}         | (184,) () (184,) (184,) (184,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pwl.conv.weight         | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.pwl.conv.weight                                                           | (184, 736, 1, 1)                   |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.se.se.0.conv.*          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.se.se.0.conv.{bias,weight}                                                | (184,) (184,736,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.se.se.1.*               | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_4.se.se.1.{bias,weight}                                                     | (736,) (736,184,1,1)               |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.dw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.dw.conv.weight                                                            | (1104, 1, 5, 5)                    |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pw.bn.*                 | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pw.bn.{bias,num_batches_tracked,running_mean,running_var,weight}          | (1104,) () (1104,) (1104,) (1104,) |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pw.conv.weight          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pw.conv.weight                                                            | (1104, 184, 1, 1)                  |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pwl.bn.*                | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pwl.bn.{bias,num_batches_tracked,running_mean,running_var,weight}         | (200,) () (200,) (200,) (200,)     |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pwl.conv.weight         | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.pwl.conv.weight                                                           | (200, 1104, 1, 1)                  |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.se.se.0.conv.*          | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.se.se.0.conv.{bias,weight}                                                | (280,) (280,1104,1,1)              |\n",
      "| roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.se.se.1.*               | roi_heads.box_head.roi_box_conv.0.fbnetv2_0_5.se.se.1.{bias,weight}                                                     | (1104,) (1104,280,1,1)             |\n",
      "| roi_heads.box_predictor.bbox_pred.*                                   | roi_heads.box_predictor.bbox_pred.{bias,weight}                                                                         | (320,) (320,200)                   |\n",
      "| roi_heads.box_predictor.cls_score.*                                   | roi_heads.box_predictor.cls_score.{bias,weight}                                                                         | (81,) (81,200)                     |\n",
      "WARNING:fvcore.common.checkpoint:The checkpoint state_dict contains keys that are not used by the model:\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_0.dw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_0.pw.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_0.pw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_0.pwl.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_0.pwl.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_1.dw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_1.pw.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_1.pw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_1.pwl.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_1.pwl.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_2.dw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_2.pw.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_2.pw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_2.pwl.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_2.pwl.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_3.dw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_3.pw.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_3.pw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_3.pwl.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_3.pwl.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_4.dw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_4.pw.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_4.pw.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_4.pwl.bn.{bias, num_batches_tracked, running_mean, running_var, weight}\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.feature_extractor.0.fbnetv2_0_4.pwl.conv.weight\u001b[0m\n",
      "  \u001b[35mroi_heads.mask_head.predictor.mask_fcn_logits.{bias, weight}\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from d2go.model_zoo import model_zoo\n",
    "model = model_zoo.get('faster_rcnn_fbnetv3a_C4.yaml', trained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Download an image from the COCO dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "from matplotlib import pyplot as plt\n",
    "!wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg\n",
    "im = cv2.imread(\"./input.jpg\")\n",
    "plt.imshow(im[:, :, ::-1])  # cv2 loads images as BGR; reverse the channels to RGB for matplotlib"
   ]
  },
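  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OpenCV's `cv2.imread` returns pixels in BGR channel order, while matplotlib and detectron2's `Visualizer` expect RGB. Slicing with `[:, :, ::-1]` reverses the last (channel) axis, which converts between the two orders. The small self-contained check below uses a synthetic 1x1 image whose pixel values are made up for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# a synthetic 1x1 image in BGR order: blue=255, green=128, red=0\n",
    "bgr = np.array([[[255, 128, 0]]], dtype=np.uint8)\n",
    "\n",
    "# reverse the channel axis to get RGB order: red=0, green=128, blue=255\n",
    "rgb = bgr[:, :, ::-1]\n",
    "print(rgb[0, 0].tolist())  # prints [0, 128, 255]\n",
    "\n",
    "# reversing twice restores the original BGR array\n",
    "assert (rgb[:, :, ::-1] == bgr).all()"
   ]
  },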
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Then we can create a `DemoPredictor` to run inference on this image and see the raw outputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from d2go.utils.demo_predictor import DemoPredictor\n",
    "predictor = DemoPredictor(model)\n",
    "outputs = predictor(im)\n",
    "# predicted category indices and their bounding boxes (x1, y1, x2, y2 in pixels)\n",
    "print(outputs[\"instances\"].pred_classes)\n",
    "print(outputs[\"instances\"].pred_boxes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Let's visualize the predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from detectron2.utils.visualizer import Visualizer\n",
    "from detectron2.data import MetadataCatalog, DatasetCatalog\n",
    "\n",
    "v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(\"coco_2017_train\"))\n",
    "out = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\"))\n",
    "plt.imshow(out.get_image()[:, :, ::-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train on a custom dataset\n",
    "In this section, we show how to train a d2go model on a custom dataset.\n",
    "\n",
    "We use [the balloon segmentation dataset](https://github.com/matterport/Mask_RCNN/tree/master/samples/balloon)\n",
    "which only has one class: balloon.\n",
    "We'll train a balloon segmentation model starting from an existing model pre-trained on the COCO dataset, available in d2go's [model zoo](https://github.com/facebookresearch/d2go/blob/master/MODEL_ZOO.md).\n",
    "\n",
    "### Prepare the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download and decompress the data\n",
    "!wget https://github.com/matterport/Mask_RCNN/releases/download/v2.1/balloon_dataset.zip\n",
    "!unzip -o balloon_dataset.zip > /dev/null"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "D2Go is built on top of detectron2. Let's register the balloon dataset with detectron2, following the [detectron2 custom dataset tutorial](https://detectron2.readthedocs.io/tutorials/datasets.html).\n",
    "Because the dataset is in a custom format, we write a function that parses it into detectron2's standard format. Users need to write such a function whenever a dataset is not in a format detectron2 already supports; see the tutorial for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if your dataset is in COCO format, this cell can be replaced by the following three lines:\n",
    "# from detectron2.data.datasets import register_coco_instances\n",
    "# register_coco_instances(\"my_dataset_train\", {}, \"json_annotation_train.json\", \"path/to/image/dir\")\n",
    "# register_coco_instances(\"my_dataset_val\", {}, \"json_annotation_val.json\", \"path/to/image/dir\")\n",
    "import os\n",
    "import json\n",
    "import cv2\n",
    "import numpy as np\n",
    "from detectron2.data import DatasetCatalog, MetadataCatalog\n",
    "from detectron2.structures import BoxMode\n",
    "\n",
    "def get_balloon_dicts(img_dir):\n",
    "    json_file = os.path.join(img_dir, \"via_region_data.json\")\n",
    "    with open(json_file) as f:\n",
    "        imgs_anns = json.load(f)\n",
    "\n",
    "    dataset_dicts = []\n",
    "    for idx, v in enumerate(imgs_anns.values()):\n",
    "        record = {}\n",
    "        \n",
    "        filename = os.path.join(img_dir, v[\"filename\"])\n",
    "        height, width = cv2.imread(filename).shape[:2]\n",
    "        \n",
    "        record[\"file_name\"] = filename\n",
    "        record[\"image_id\"] = idx\n",
    "        record[\"height\"] = height\n",
    "        record[\"width\"] = width\n",
    "      \n",
    "        annos = v[\"regions\"]\n",
    "        objs = []\n",
    "        for _, anno in annos.items():\n",
    "            assert not anno[\"region_attributes\"]\n",
    "            anno = anno[\"shape_attributes\"]\n",
    "            px = anno[\"all_points_x\"]\n",
    "            py = anno[\"all_points_y\"]\n",
    "            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]\n",
    "            poly = [p for x in poly for p in x]\n",
    "\n",
    "            obj = {\n",
    "                \"bbox\": [np.min(px), np.min(py), np.max(px), np.max(py)],\n",
    "                \"bbox_mode\": BoxMode.XYXY_ABS,\n",
    "                \"segmentation\": [poly],\n",
    "                \"category_id\": 0,\n",
    "            }\n",
    "            objs.append(obj)\n",
    "        record[\"annotations\"] = objs\n",
    "        dataset_dicts.append(record)\n",
    "    return dataset_dicts\n",
    "\n",
    "for d in [\"train\", \"val\"]:\n",
    "    DatasetCatalog.register(\"balloon_\" + d, lambda d=d: get_balloon_dicts(\"balloon/\" + d))\n",
    "    MetadataCatalog.get(\"balloon_\" + d).set(thing_classes=[\"balloon\"], evaluator_type=\"coco\")\n",
    "balloon_metadata = MetadataCatalog.get(\"balloon_train\")"
   ]
  },
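  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The core of `get_balloon_dicts` is converting each VIA polygon region into a detectron2 annotation dict. The cell below is a minimal, dependency-free sketch of that single step on one hand-written region. The sample coordinates and the helper name `via_region_to_d2_obj` are hypothetical and only mirror the fields read in the cell above (`all_points_x`, `all_points_y`). It uses plain `min`/`max` instead of NumPy and omits `bbox_mode`, which would be `BoxMode.XYXY_ABS` as in `get_balloon_dicts`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical VIA-style region, shaped like the values under v[\"regions\"] in via_region_data.json\n",
    "region = {\n",
    "    \"shape_attributes\": {\n",
    "        \"name\": \"polygon\",\n",
    "        \"all_points_x\": [10, 50, 30],\n",
    "        \"all_points_y\": [20, 25, 60],\n",
    "    },\n",
    "    \"region_attributes\": {},\n",
    "}\n",
    "\n",
    "def via_region_to_d2_obj(region):\n",
    "    \"\"\"Convert one VIA polygon region into a detectron2-style annotation dict (illustrative helper).\"\"\"\n",
    "    shape = region[\"shape_attributes\"]\n",
    "    px, py = shape[\"all_points_x\"], shape[\"all_points_y\"]\n",
    "    # shift each vertex by half a pixel, then flatten to [x0, y0, x1, y1, ...]\n",
    "    poly = [c for x, y in zip(px, py) for c in (x + 0.5, y + 0.5)]\n",
    "    return {\n",
    "        \"bbox\": [min(px), min(py), max(px), max(py)],  # XYXY absolute pixel coordinates\n",
    "        \"segmentation\": [poly],  # list of polygons; one polygon per region here\n",
    "        \"category_id\": 0,  # the only class: balloon\n",
    "    }\n",
    "\n",
    "obj = via_region_to_d2_obj(region)\n",
    "print(obj[\"bbox\"])          # [10, 20, 50, 60]\n",
    "print(obj[\"segmentation\"])  # [[10.5, 20.5, 50.5, 25.5, 30.5, 60.5]]"
   ]
  },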
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To verify that the data loading is correct, let's visualize the annotations of three randomly selected samples from the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "dataset_dicts = get_balloon_dicts(\"balloon/train\")\n",
    "for d in random.sample(dataset_dicts, 3):\n",
    "    img = cv2.imread(d[\"file_name\"])\n",
    "    visualizer = Visualizer(img[:, :, ::-1], metadata=balloon_metadata, scale=0.5)\n",
    "    out = visualizer.draw_dataset_dict(d)\n",
    "    plt.figure()\n",
    "    plt.imshow(out.get_image())  # Visualizer output is already RGB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train\n",
    "Now, let's fine-tune a COCO-pretrained FBNetV3A Faster R-CNN model on the balloon dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from d2go.model_zoo import model_zoo\n",
    "from d2go.runner import GeneralizedRCNNRunner\n",
    "\n",
    "\n",
    "def prepare_for_launch():\n",
    "    runner = GeneralizedRCNNRunner()\n",
    "    cfg = runner.get_default_cfg()\n",
    "    cfg.merge_from_file(model_zoo.get_config_file(\"faster_rcnn_fbnetv3a_C4.yaml\"))\n",
    "    cfg.MODEL_EMA.ENABLED = False\n",
    "    cfg.DATASETS.TRAIN = (\"balloon_train\",)\n",
    "    cfg.DATASETS.TEST = (\"balloon_val\",)\n",
    "    cfg.DATALOADER.NUM_WORKERS = 2\n",
    "    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"faster_rcnn_fbnetv3a_C4.yaml\")  # Let training initialize from model zoo\n",
    "    cfg.MODEL.DEVICE = \"cpu\" if ('CI' in os.environ) else \"cuda\"\n",
    "    cfg.SOLVER.IMS_PER_BATCH = 2\n",
    "    cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR\n",
    "    cfg.SOLVER.MAX_ITER = 5 if ('CI' in os.environ) else 600    # 600 iterations seem good enough for this toy dataset; you will need to train longer for a practical dataset\n",
    "    cfg.SOLVER.STEPS = []        # do not decay learning rate\n",
    "    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset (default: 512)\n",
    "    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only one class (balloon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)\n",
    "    # NOTE: this config is the number of classes; a few popular unofficial tutorials incorrectly use num_classes+1 here.\n",
    "    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n",
    "    return cfg, runner\n",
    "\n",
    "cfg, runner = prepare_for_launch()\n",
    "model = runner.build_model(cfg)\n",
    "runner.do_train(cfg, model, resume=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference & evaluation using the trained model\n",
    "Now, let's run inference with the trained model on the balloon validation dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = runner.do_test(cfg, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation results are:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
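    "## Export to Int8 Model\n",
    "This section exports an int8 model using post-training quantization. Note that it exports a COCO-pretrained model from the model zoo (`faster_rcnn_fbnetv3a_dsmask_C4.yaml`), not the balloon model trained above. For quantization-aware training, please see the [instructions](https://github.com/facebookresearch/d2go/tree/master/demo#quantization-aware-training)."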
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "\n",
    "from d2go.export.exporter import convert_and_export_predictor\n",
    "from d2go.model_zoo import model_zoo\n",
    "from d2go.utils.testing.data_loader_helper import create_detection_data_loader_on_toy_dataset\n",
    "\n",
    "# silence INFO-level (and lower) log messages during export; remember the previous setting\n",
    "previous_level = logging.root.manager.disable\n",
    "logging.disable(logging.INFO)\n",
    "\n",
    "cfg_name = 'faster_rcnn_fbnetv3a_dsmask_C4.yaml'\n",
    "pytorch_model = model_zoo.get(cfg_name, trained=True, device='cpu')\n",
    "pytorch_model.eval()\n",
    "cfg = model_zoo.get_config(cfg_name)\n",
    "\n",
    "try:\n",
    "    # a data loader over a small toy dataset supplies sample inputs for the int8 conversion\n",
    "    with create_detection_data_loader_on_toy_dataset(cfg, 224, 320, is_train=False) as data_loader:\n",
    "        predictor_path = convert_and_export_predictor(\n",
    "            cfg,\n",
    "            copy.deepcopy(pytorch_model),\n",
    "            \"torchscript_int8\",\n",
    "            './',\n",
    "            data_loader,\n",
    "        )\n",
    "finally:\n",
    "    # restore the previous logging level even if the export fails\n",
    "    logging.disable(previous_level)"
   ]
  },
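  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Post-training quantization maps float tensors to 8-bit integers through a scale and a zero point, which are typically estimated by running calibration batches through the model (here, presumably the toy data loader passed to `convert_and_export_predictor`). The cell below is an illustrative sketch of per-tensor affine (asymmetric) uint8 quantization arithmetic in plain Python. The helper names and the activation range `[-1.0, 3.0]` are hypothetical, and this is not d2go's calibration code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def affine_qparams(x_min, x_max, qmin=0, qmax=255):\n",
    "    \"\"\"Per-tensor asymmetric (affine) quantization parameters for an observed float range.\"\"\"\n",
    "    # widen the range to include 0.0 so that zero is exactly representable\n",
    "    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)\n",
    "    scale = (x_max - x_min) / (qmax - qmin)\n",
    "    zero_point = int(round(qmin - x_min / scale))\n",
    "    zero_point = max(qmin, min(qmax, zero_point))  # clamp to the integer range\n",
    "    return scale, zero_point\n",
    "\n",
    "def quantize(x, scale, zero_point, qmin=0, qmax=255):\n",
    "    # q = clamp(round(x / scale) + zero_point, qmin, qmax)\n",
    "    return max(qmin, min(qmax, int(round(x / scale)) + zero_point))\n",
    "\n",
    "def dequantize(q, scale, zero_point):\n",
    "    # x_hat = (q - zero_point) * scale\n",
    "    return (q - zero_point) * scale\n",
    "\n",
    "# hypothetical activation range recorded during calibration\n",
    "scale, zero_point = affine_qparams(-1.0, 3.0)\n",
    "q = quantize(0.5, scale, zero_point)\n",
    "x_hat = dequantize(q, scale, zero_point)\n",
    "print(zero_point, q)                  # 64 96\n",
    "print(abs(x_hat - 0.5) <= scale / 2)  # True: rounding error is at most half a quantization step"
   ]
  },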
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the predictor using the exported int8 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mobile_cv.predictor.api import create_predictor\n",
    "model = create_predictor(predictor_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make predictions and visualize the output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from d2go.utils.demo_predictor import DemoPredictor\n",
    "predictor = DemoPredictor(model)\n",
    "outputs = predictor(im)\n",
    "\n",
    "v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(\"coco_2017_train\"))\n",
    "out = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\"))\n",
    "plt.imshow(out.get_image())  # Visualizer output is already RGB"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "517234ba8a8c2628fc901a4b482ee7ad83e05e3ca55f6d4e216a0b65fa2f59e9"
  },
  "kernelspec": {
   "display_name": "Python 3.9.13 ('d2go-x9zKx9Ui')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
